<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>TensorSpace - Training LeNet</title>
    <meta name="author" content="syt123450 / https://github.com/syt123450">
    <meta name="description" content="Demo to show how to use TensorSpace to show training process">

    <script src="../lib/jquery.min.js"></script>
    <script src="../lib/three.min.js"></script>
    <script src="../lib/stats.min.js"></script>
    <script src="../lib/tween.min.js"></script>
    <script src="../lib/TrackballControls.js"></script>
    <script src="../lib/tf.min.js"></script>
    <script src="./data/mnist_data.js"></script>
    <script src="../../dist/tensorspace.js"></script>

    <style>

        html, body {
            margin: 0;
            padding: 0;
            width: 100%;
            height: 100%;
        }

        #container {
            width: 100%;
            height: 100%;
        }

    </style>

</head>
<body>

<div id="container"></div>

<script>

    // Sample input for the digit "5", fetched from ./data/5.json below and
    // fed to the visualization after every training epoch.
    let data5;
    // MNIST data set wrapper; assigned in load().
    let data;

	// Two tfjs models over the same layers: one exposing every intermediate
	// output (for visualization), one with only the final output (for training).
	let [ tfjsLoadModel, tfjsTrainingModel ] = constructorTfjsModel();
    let tspModel = constructorTspModel();

	// Bind the TensorSpace model to the in-memory tfjs model.
	tspModel.load( {

		type: "live",
		modelHandler: tfjsLoadModel

	} );

	// Once the visualization is initialized, fetch the sample digit and then
	// start loading the data set and training.
	tspModel.init( function() {

		$.ajax( {

			url: "./data/5.json",
			type: 'GET',
			async: true,
			dataType: 'json',
			// Parameter is named `sample` so it does not shadow the
			// script-level `data` variable holding the MNIST data set.
			success: function ( sample ) {

				data5 = sample;

				// loadAndTrain() is async; surface failures instead of
				// leaving an unhandled promise rejection.
				loadAndTrain().catch( function ( err ) {

					console.error( "Loading the data set or training failed:", err );

				} );

			},
			// Without this handler a failed request would silently leave
			// the demo idle.
			error: function ( jqXHR, textStatus, errorThrown ) {

				console.error( "Failed to load ./data/5.json:", textStatus, errorThrown );

			}

		} );


	} );

	/**
	 * Runs the two asynchronous stages of the demo strictly in sequence:
	 * load() first, then train().
	 *
	 * @returns {Promise<void>} settles once both stages have finished.
	 */
	async function loadAndTrain() {

		// Order matters: train() reads the `data` variable that load() assigns.
		const stages = [ load, train ];

		for ( const runStage of stages ) {

			await runStage();

		}

	}

	/**
	 * Builds the LeNet-style network with the tfjs functional API and wraps
	 * the same layer graph in two models.
	 *
	 * @returns {Array} a two-element array `[ tfjsLoadModel, tfjsTrainingModel ]`:
	 *   - tfjsLoadModel: multi-output model exposing the output of every layer
	 *     except the flatten layer (used as the "live" model for visualization);
	 *   - tfjsTrainingModel: single-output model (final softmax layer), already
	 *     compiled with Adam, categorical cross-entropy and an accuracy metric.
	 */
	function constructorTfjsModel() {

		// Small factories for the layer kinds that appear more than once.
		const makeConv = ( filters, name ) => tf.layers.conv2d( {
			kernelSize: 5,
			filters: filters,
			strides: 1,
			activation: 'relu',
			kernelInitializer: 'VarianceScaling',
			name: name
		} );

		const makePool = ( name ) => tf.layers.maxPooling2d( {
			poolSize: [ 2, 2 ],
			strides: [ 2, 2 ],
			name: name
		} );

		const makeDense = ( units, activation, name ) => tf.layers.dense( {
			units: units,
			kernelInitializer: 'VarianceScaling',
			activation: activation,
			name: name
		} );

		const mnistInput = tf.input( {
			shape: [ 28, 28, 1 ],
			name: "mnistInput"
		} );

		// Layers in forward order. `exposed` marks whether the layer's output
		// becomes one of the visualization model's outputs; the flatten layer
		// is the only one left out.
		const pipeline = [
			{ layer: tf.layers.zeroPadding2d( { padding: [ 2, 2 ], name: "myPadding" } ), exposed: true },
			{ layer: makeConv( 6, "myConv1" ), exposed: true },
			{ layer: makePool( "myMaxPooling1" ), exposed: true },
			{ layer: makeConv( 16, "myConv2" ), exposed: true },
			{ layer: makePool( "myMaxPooling2" ), exposed: true },
			{ layer: tf.layers.flatten( { name: "myFlatten" } ), exposed: false },
			{ layer: makeDense( 120, 'relu', "myDense1" ), exposed: true },
			{ layer: makeDense( 84, 'relu', "myDense2" ), exposed: true },
			{ layer: makeDense( 10, 'softmax', "myDense3" ), exposed: true }
		];

		// Chain the layers, collecting the intermediate tensors to expose.
		const exposedOutputs = [];
		let current = mnistInput;

		for ( const step of pipeline ) {

			current = step.layer.apply( current );

			if ( step.exposed ) {

				exposedOutputs.push( current );

			}

		}

		// After the loop `current` is the output of the last (softmax) layer.
		const finalOutput = current;

		const visualizationModel = tf.model( {
			inputs: mnistInput,
			outputs: exposedOutputs
		} );

		const trainingModel = tf.model( {
			inputs: mnistInput,
			outputs: finalOutput
		} );

		trainingModel.compile( {
			optimizer: tf.train.adam( 0.0001 ),
			loss: 'categoricalCrossentropy',
			metrics: [ 'accuracy' ]
		} );

		return [ visualizationModel, trainingModel ];

	}

    /**
     * Creates the TensorSpace Sequential model rendered inside the
     * `#container` element and populates it with the visualization layers.
     *
     * @returns the configured TSP.models.Sequential instance (neither loaded
     *   nor initialized yet).
     */
    function constructorTspModel() {

		const host = document.getElementById( "container" );

		const sequential = new TSP.models.Sequential( host, {
			animeTime: 200,
			stats: true
		} );

		// Each entry is [ LayerConstructor, ...constructorArguments ], listed in
		// the order the layers are added to the model.
		const layerPlan = [
			[ TSP.layers.GreyscaleInput ],
			[ TSP.layers.Padding2d ],
			[ TSP.layers.Conv2d, { initStatus: "open" } ],
			[ TSP.layers.Pooling2d ],
			[ TSP.layers.Conv2d ],
			[ TSP.layers.Pooling2d ],
			[ TSP.layers.Dense ],
			[ TSP.layers.Dense ],
			[ TSP.layers.Output1d, {
				outputs: [ "0", "1", "2", "3", "4", "5", "6", "7", "8", "9" ],
				initStatus: "open"
			} ]
		];

		// Construct and add one layer at a time, forwarding exactly the
		// arguments listed for it (none where the entry has no config).
		for ( const [ LayerCtor, ...ctorArgs ] of layerPlan ) {

			sequential.add( new LayerCtor( ...ctorArgs ) );

		}

		return sequential;

    }

	/**
	 * Creates the MnistData instance, stores it in the script-level `data`
	 * variable, and waits for its own load() to complete.
	 *
	 * @returns {Promise<void>}
	 */
	async function load() {

		const mnist = new MnistData();

		// Publish the instance before awaiting, as the rest of the script
		// accesses it through `data`.
		data = mnist;

		await mnist.load();

	}

	/**
	 * Trains tfjsTrainingModel on `data.trainDataset` and, at the end of every
	 * epoch, logs the accuracy and re-runs the TensorSpace prediction on the
	 * sample digit so the visualization reflects the current weights.
	 *
	 * NOTE(review): `data.trainDataset` is presumably a tf.data.Dataset provided
	 * by MnistData (./data/mnist_data.js, not visible here) — confirm.
	 *
	 * @returns {Promise<void>} settles when fitDataset() finishes.
	 */
	async function train() {

		// Training schedule.
		const samplesPerBatch = 10;
		const batchesInEpoch = 10;
		const epochCount = 100;

		const batchedTrainSet = data.trainDataset.batch( samplesPerBatch );

		// Invoked by tfjs after each epoch; `logs.acc` is read as the accuracy
		// value (the model is compiled with the 'accuracy' metric).
		const reportEpoch = async ( epoch, logs ) => {

			console.log( logs.acc );
			tspModel.predict( data5 );

		};

		await tfjsTrainingModel.fitDataset( batchedTrainSet, {
			batchesPerEpoch: batchesInEpoch,
			epochs: epochCount,
			callbacks: {
				onEpochEnd: reportEpoch
			}
		} );

	}

</script>

</body>
</html>